from data.cow import CowDataset
from data.llff import LlffDataset
from data.synthetic import BlenderDataset

def load_dataset(data_conf, split='train', epoch_len=6000):
    """Build the dataset selected by ``data_conf.dataset_type`` for ``split``.

    Args:
        data_conf: Config object. ``dataset_type`` is always read. The
            ``'cow'`` type also reads ``num_training_views`` (train) or
            ``num_validation_views`` (val). The ``'llff'`` and
            ``'synthetic'`` types read ``loader_args``, a mapping forwarded
            to the loader as keyword arguments.
        split: Dataset split. For ``'cow'`` only ``'train'`` and ``'val'``
            are supported; other dataset types pass it through unchanged.
        epoch_len: Epoch length forwarded to the loader. Not used for the
            ``'cow'`` validation split, which uses ``num_validation_views``
            as its epoch length instead.

    Returns:
        A ``CowDataset``, ``LlffDataset`` or ``BlenderDataset`` instance.

    Raises:
        NotImplementedError: If ``dataset_type`` is not one of ``'cow'``,
            ``'llff'`` or ``'synthetic'``, or if ``dataset_type`` is
            ``'cow'`` and ``split`` is neither ``'train'`` nor ``'val'``.
    """
    dataset_type = data_conf.dataset_type

    if dataset_type == 'cow':
        if split == 'train':
            return CowDataset(
                data_conf.num_training_views,
                epoch_len=epoch_len,
            )
        if split == 'val':
            # Epoch length equals the number of validation views
            # (presumably one pass over every view — confirm in CowDataset).
            return CowDataset(
                data_conf.num_validation_views,
                epoch_len=data_conf.num_validation_views,
            )
        # Reject other splits explicitly rather than implicitly returning None.
        raise NotImplementedError(
            f"Unsupported split {split!r} for dataset_type 'cow'"
        )

    if dataset_type == 'llff':
        return LlffDataset(
            split=split, epoch_len=epoch_len, **data_conf.loader_args
        )

    if dataset_type == 'synthetic':
        return BlenderDataset(
            split=split, epoch_len=epoch_len, **data_conf.loader_args
        )

    raise NotImplementedError(f"Unknown dataset_type: {dataset_type!r}")
